package com.shujia.spark.streaming

import org.apache.kafka.clients.consumer.ConsumerRecord
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import org.apache.spark.streaming.dstream.InputDStream
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent
import org.apache.spark.streaming.{Durations, StreamingContext}

object Demo8BlackFilter {

  /**
    * Demonstrates "dynamically updating" a broadcast variable in Spark Streaming:
    * because a broadcast value is immutable, each micro-batch re-reads the
    * blacklist from MySQL inside `foreachRDD` (driver side) and broadcasts a
    * fresh copy. Records from Kafka whose mdn (first comma-separated field of
    * the message value) appears on the blacklist are kept and appended to a
    * MySQL table; the batch's broadcast is then released to avoid leaking one
    * broadcast copy per batch.
    */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("black")
      .master("local[2]")
      .getOrCreate()
    import spark.implicits._

    // 5-second micro-batches on top of the existing SparkContext.
    val ssc = new StreamingContext(spark.sparkContext, Durations.seconds(5))

    val kafkaParams: Map[String, Object] = Map[String, Object](
      "bootstrap.servers" -> "master:9092,node1:9092,node2:9092",
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "asdasdasdas",
      "auto.offset.reset" -> "latest", // latest: start from newly arriving data only
      "enable.auto.commit" -> "false"
    )

    // Kafka topics to subscribe to.
    val topics = Array("dianxin")

    val linesDS: InputDStream[ConsumerRecord[String, String]] = KafkaUtils.createDirectStream[String, String](
      ssc,
      PreferConsistent,
      Subscribe[String, String](topics, kafkaParams)
    )

    // Driver side, runs exactly once (stream graph construction time).
    println("foreacRDD外面")

    linesDS.foreachRDD(foreachFunc = rdd => {
      // Driver side, runs once per batch interval.
      println("foreachRDD内部，算子外部")

      // Skip the JDBC round-trip and broadcast entirely when the batch is empty.
      if (!rdd.isEmpty()) {

        // Re-read the current blacklist from MySQL for this batch.
        val blackListDF: DataFrame = spark.read
          .format("jdbc")
          .option("url", "jdbc:mysql://master:3306")
          .option("dbtable", "student.t_blacklist")
          .option("user", "root")
          .option("password", "123456")
          .load()

        // Collected to the driver; assumes the table has a single string
        // column holding the blacklisted mdn values.
        val blackList: Array[String] = blackListDF.as[String].collect()

        println(blackList.mkString(","))

        // Broadcast this batch's snapshot of the blacklist to the executors.
        val broadCastBlackList: Broadcast[Array[String]] = spark.sparkContext.broadcast(blackList)

        // Keep only records whose mdn (first CSV field) is blacklisted.
        val blackRDD: RDD[ConsumerRecord[String, String]] = rdd.filter(record => {
          val value: String = record.value()
          val mdn: String = value.split(",")(0)

          // Executor side: read the broadcast value, not the driver variable.
          val blackListValue: Array[String] = broadCastBlackList.value
          blackListValue.contains(mdn)
        })

        // Append the matched (blacklisted) records to MySQL.
        blackRDD
          .map(_.value())
          .toDF("mdn")
          .write
          .mode(SaveMode.Append)
          .format("jdbc")
          .option("url", "jdbc:mysql://master:3306")
          .option("dbtable", "student.dianxin_black")
          .option("user", "root")
          .option("password", "123456")
          .save()

        // Release this batch's broadcast. Without this, every batch leaks a
        // new broadcast copy of the blacklist on the driver and executors.
        broadCastBlackList.unpersist()
      }
    })

    ssc.start()
    ssc.awaitTermination()
    ssc.stop()
  }
}
